package rsa

import (
	"bytes"
	"crypto/rand"
	"crypto/rsa"
	"crypto/x509"
	"encoding/base64"
	"encoding/pem"
	"errors"
	"fmt"
	"log"
	"os"
	"runtime"
)

// File names and PEM block types used when writing key files.
//
// Despite their names, the *Prefix constants are PEM block types: the label
// that follows "-----BEGIN " and "-----END ".
//   - privateKeyPrefix labels PKCS #1 DER (x509.MarshalPKCS1PrivateKey). Its
//     conventional label is "RSA PRIVATE KEY"; plain "PRIVATE KEY" denotes
//     PKCS #8.
//   - publicKeyPrefix labels PKIX DER (x509.MarshalPKIXPublicKey). Its
//     conventional label is "PUBLIC KEY", with no trailing space.
const (
	privateFileName  = "private.pem"
	publicFileName   = "public.pem"
	privateKeyPrefix = "RSA PRIVATE KEY"
	publicKeyPrefix  = "PUBLIC KEY"
)

// GenRsaKey generates a new 2048-bit RSA key pair. It writes the pair as PEM
// into two files whose relative names resolve against the current working
// directory:
//
//   - privateFileName: the private key as PKCS #1 DER in a privateKeyPrefix
//     block. A newly created file gets mode 0600 because the key is secret.
//   - publicFileName: the public key as PKIX DER in a publicKeyPrefix block.
//
// Existing files with those names are truncated and overwritten; their
// permission bits are not changed. Every failure is returned as an error
// instead of panicking or being dropped. This covers key generation,
// marshalling, file creation, writing and closing.
func GenRsaKey() error {
	// writePEM encodes block into the named file, creating it with perm if needed.
	writePEM := func(name string, block *pem.Block, perm os.FileMode) error {
		f, err := os.OpenFile(name, os.O_WRONLY|os.O_CREATE|os.O_TRUNC, perm)
		if err != nil {
			return err
		}
		if err := pem.Encode(f, block); err != nil {
			_ = f.Close() // the encode error is the one worth reporting
			return err
		}
		// Close can surface deferred write errors, so its result matters.
		return f.Close()
	}

	privateKey, err := rsa.GenerateKey(rand.Reader, 2048)
	if err != nil {
		return err
	}

	privateBlock := &pem.Block{
		Type:  privateKeyPrefix,
		Bytes: x509.MarshalPKCS1PrivateKey(privateKey),
	}
	if err := writePEM(privateFileName, privateBlock, 0600); err != nil {
		return err
	}

	x509PublicKey, err := x509.MarshalPKIXPublicKey(&privateKey.PublicKey)
	if err != nil {
		return err
	}
	publicBlock := &pem.Block{
		Type:  publicKeyPrefix,
		Bytes: x509PublicKey,
	}
	// 0666 matches the mode os.Create would have used (before umask).
	return writePEM(publicFileName, publicBlock, 0666)
}

// RsaEncrypt encrypts plainText with pubKey using RSA PKCS #1 v1.5. pubKey is
// a PEM-encoded PKIX RSA public key. The result is the ciphertext as unpadded
// base64url text (base64.RawURLEncoding).
//
// PKCS #1 v1.5 encrypts at most k-11 bytes per operation, where k is the
// modulus size in bytes. plainText is therefore split into chunks of at most
// k-11 bytes, each chunk is encrypted on its own, and the k-byte ciphertexts
// are concatenated before encoding. An empty plainText yields "".
//
// A panic raised while encrypting is still logged as before, but it is now
// returned as an error instead of being reported as success with an empty
// result.
//
// NOTE(review): PKCS #1 v1.5 encryption is kept for compatibility with
// existing ciphertexts; new protocols should prefer rsa.EncryptOAEP.
func RsaEncrypt(plainText, pubKey []byte) (encoded string, err error) {
	defer func() {
		if r := recover(); r != nil {
			switch r.(type) {
			case runtime.Error:
				log.Println("runtime err:", r, "Check that the key is correct")
			default:
				log.Println("error:", r)
			}
			encoded, err = "", fmt.Errorf("rsa: encrypt: recovered from panic: %v", r)
		}
	}()

	publicKey, err := getPublicKey(pubKey)
	if err != nil {
		return "", err
	}
	// Defensive: never dereference a nil key reported without an error.
	if publicKey == nil {
		return "", errors.New("rsa: invalid public key")
	}

	blockLen := publicKey.Size()
	chunkLen := blockLen - 11 // PKCS #1 v1.5 padding needs at least 11 bytes
	if chunkLen <= 0 {
		return "", fmt.Errorf("rsa: %d-bit key is too small for PKCS #1 v1.5 encryption", publicKey.N.BitLen())
	}

	chunks := splitRsaData(plainText, chunkLen)

	var buffer bytes.Buffer
	buffer.Grow(len(chunks) * blockLen)
	for _, chunk := range chunks {
		cipherChunk, encErr := rsa.EncryptPKCS1v15(rand.Reader, publicKey, chunk)
		if encErr != nil {
			return "", encErr
		}
		buffer.Write(cipherChunk)
	}

	return base64.RawURLEncoding.EncodeToString(buffer.Bytes()), nil
}

// RsaDecrypt reverses RsaEncrypt.
//
// cryptText must be unpadded base64url (base64.RawURLEncoding) text. Once
// decoded, it must be a concatenation of k-byte RSA PKCS #1 v1.5 ciphertexts,
// where k is the modulus size of the private key. priKey is a PEM-encoded
// PKCS #1 RSA private key. Each ciphertext block is decrypted and the
// plaintexts are concatenated. An empty cryptText yields "".
//
// pubKey is parsed and any parse error is returned, as before. The block size
// is taken from the private key, whose modulus determined the ciphertext block
// length, rather than from pubKey.
//
// A panic raised while decrypting is still logged, but it is now returned as
// an error instead of being reported as success with an empty result.
//
// NOTE(review): distinct PKCS #1 v1.5 decryption errors can act as a padding
// oracle; do not expose them verbatim to untrusted callers.
func RsaDecrypt(cryptText string, pubKey, priKey []byte) (plainText string, err error) {
	defer func() {
		if r := recover(); r != nil {
			switch r.(type) {
			case runtime.Error:
				log.Println("runtime err:", r, "Check that the key is correct")
			default:
				log.Println("error:", r)
			}
			plainText, err = "", fmt.Errorf("rsa: decrypt: recovered from panic: %v", r)
		}
	}()

	privateKey, err := getPrivateKey(priKey)
	if err != nil {
		return "", err
	}
	// Defensive: never dereference a nil key reported without an error.
	if privateKey == nil {
		return "", errors.New("rsa: invalid private key")
	}

	if _, pubErr := getPublicKey(pubKey); pubErr != nil {
		return "", pubErr
	}

	raw, err := base64.RawURLEncoding.DecodeString(cryptText)
	if err != nil {
		return "", err
	}

	// Each ciphertext block is exactly Size() bytes (the modulus length rounded
	// up to whole bytes), not BitLen()/8.
	blockLen := privateKey.Size()
	if blockLen <= 0 || len(raw)%blockLen != 0 {
		return "", fmt.Errorf("rsa: ciphertext length %d is not a multiple of the %d-byte key size", len(raw), blockLen)
	}

	var buffer bytes.Buffer
	buffer.Grow(len(raw))
	for _, chunk := range splitRsaData(raw, blockLen) {
		content, decErr := rsa.DecryptPKCS1v15(rand.Reader, privateKey, chunk)
		if decErr != nil {
			return "", decErr
		}
		buffer.Write(content)
	}

	return buffer.String(), nil
}

func getPublicKey(key []byte) (*rsa.PublicKey, error) {
	block, _ := pem.Decode(key)
	defer func() {
		if err := recover(); err != nil {
			switch err.(type) {
			case runtime.Error:
				log.Println("runtime err:", err, "Check that the key is correct")
			default:
				log.Println("error:", err)
			}
		}
	}()
	publicKeyInterface, err := x509.ParsePKIXPublicKey(block.Bytes)
	if err != nil {
		return nil, err
	}
	return publicKeyInterface.(*rsa.PublicKey), nil
}

func getPrivateKey(key []byte) (*rsa.PrivateKey, error) {
	block, _ := pem.Decode(key)
	defer func() {
		if err := recover(); err != nil {
			switch err.(type) {
			case runtime.Error:
				log.Println("runtime err:", err, "Check that the key is correct")
			default:
				log.Println("error:", err)
			}
		}
	}()

	privateKey, err := x509.ParsePKCS1PrivateKey(block.Bytes)
	if err != nil {
		return nil, err
	}
	return privateKey, nil
}

// splitRsaData cuts buf into consecutive pieces of exactly lim bytes. The last
// piece holds the leftover bytes and may be shorter. Empty input produces an
// empty (non-nil) result.
//
// The pieces are sub-slices of buf and share its backing array; nothing is
// copied. lim must be positive: lim == 0 divides by zero, and a negative lim
// produces an invalid slice bound.
func splitRsaData(buf []byte, lim int) [][]byte {
	pieces := make([][]byte, 0, len(buf)/lim+1)

	rest := buf
	for ; len(rest) >= lim; rest = rest[lim:] {
		pieces = append(pieces, rest[:lim])
	}

	if len(rest) != 0 {
		pieces = append(pieces, rest)
	}
	return pieces
}
